import numpy as np
import abc

from .graph import Graph, default_graph

class Node(object):
    '''
    计算图 节点 基类
    '''

    def __init__(self, *parents, **kargs):
        # 计算图对象，默认为全局对象default_graph
        self.graph = kargs.get('graph', default_graph)
        self.need_save = kargs.get('need_save', True)
        self.get_node_name(**kargs)

        self.parents = list(parents)    # 父节点列表，接受数量不定的其他Node类对象作为本节点的父节点
        self.children = []  # 子节点列表
        self.value = None   # 当前节点的值
        self.jacobi = None  # 结果节点对当前节点的雅可比矩阵

        # 将本节点添加到父节点的子节点列表中
        for parent in parents:
            parent.children.append(self)

        # 将本节点添加到计算图中
        self.graph.add_node(self)

    def get_parents(self):
        '''
        获取本节点的父节点列表
        :return:
        '''
        return self.parents

    def get_children(self):
        '''
        获取本节点的子节点列表
        :return:
        '''
        return self.children

    def get_node_name(self, **kargs):
        '''
        生成节点名称，如果用户不指定，则根据节点类型生成类似于MatMul:3的节点名，
        如果指定了name_scope，则生成类似Hidden/MatMul:3的节点名
        :param kargs:
        :return:
        '''
        self.name = kargs.get('name', '{}:{}'.format(self.__class__.__name__, self.graph.node_count()))
        if self.graph.name_scope:
            self.name = '{}/{}'.format(self.graph.name_scope, self.name)


    def forward(self):
        '''
        前向传播计算本节点的值，若父节点的值未被计算，则递归调用父节点的forward方法
        :return:
        '''
        for node in self.parents:
            if node.value is None:
                node.forward()
        self.compute()

    @abc.abstractmethod
    def compute(self):
       '''
       抽象方法，根据父节点的值计算本节点的值
       :return:
       '''

    @abc.abstractmethod
    def get_jacobi(self, parent):
        '''
        抽象方法，计算本节点对某个父节点的雅可比矩阵
        :param parent:
        :return:
        '''

    def backward(self, result):
        '''
        反向传播，计算结果节点对本节点的雅可比矩阵
        :param result:最终结果节点
        :return:
        '''
        # 先判断当前节点的雅可比矩阵是否已经被计算过了，若是，则不必重复计算
        if self.jacobi is None:
            # 如果当前节点就是结果节点，则只需要构造一个适当形状的单位矩阵即可作为雅可比矩阵
            if self is result:
                self.jacobi = np.mat(np.eye(self.dimension()))
            else:
                # 先构造一个适当形状的全零矩阵 作为 累加器
                self.jacobi = np.mat(np.zeros((result.dimension(), self.dimension())))

                for child in self.get_children():
                    if child.value is not None:
                        # 将结果节点对子节点的雅可比矩阵 与 子节点对父节点（当前节点）的雅可比矩阵相乘，并累加，得到最终父节点的雅可比矩阵
                        self.jacobi += child.backward(result) * child.get_jacobi(self)
        return self.jacobi

    def clear_jacobi(self):
        '''
        清空结果节点对本节点的雅可比矩阵
        :return:
        '''
        self.jacobi = None

    def dimension(self):
        '''
        返回本节点的值展开成向量后的维数
        :return:
        '''
        return self.value.shape[0] * self.value.shape[1]

    def shape(self):
        '''
        返回本节点的值作为矩阵的形状（行数，列数）
        :return:
        '''
        return self.value.shape

    def reset_value(self, recursive=True):
        '''
        重置本节点的值，并递归重置本节点的下游节点值
        :param recursive:
        :return:
        '''
        self.value = None
        if recursive:
            for child in self.get_children():
                child.reset_value()



class Variable(Node):
    '''
    变量节点
    '''
    def __init__(self, dim, init=False, trainable=True, **kargs):
        '''
        变量节点没有父节点，其构造函数接受变量的形状，是否初始化以及是否参与训练的标识
        :param dim:
        :param init:
        :param trainable:
        :param kargs:
        '''
        Node.__init__(self, **kargs)
        self.dim = dim

        # 如果需要初始化，则以正太分布随机初始化变量的值
        if init:
            self.value = np.mat(np.random.normal(0, 0.001, self.dim))

        # 变量节点是否参与训练
        self.trainable = trainable

    def set_value(self, value):
        '''
        为变量赋值
        :param value:
        :return:
        '''
        assert isinstance(value, np.matrix) and value.shape == self.dim

        # 本节点的值被改变，重置所有下游节点的值
        self.reset_value()
        self.value = value
